package com.tiantian.framework.handler;

import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.ObjectUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.io.IOException;
import java.util.Collection;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

@Slf4j
@Component
public class NoticeWebSocketHandler extends TextWebSocketHandler {

    // 静态变量，用来记录当前在线连接数。应该把它设计成线程安全的。
    public static final AtomicInteger onlineNum = new AtomicInteger();
    // concurrent包的线程安全Set，用来存放每个客户端对应的WebSocketServer对象。
    private static final ConcurrentHashMap<String, WebSocketSession> sessionPools = new ConcurrentHashMap<>();


    /**
     * 在连接建立后的操作
     *
     * @param session WebSocketSession
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) {
        // 获取到拦截器中用户ID
        String userId = getUserId(session);

        // 对重复连接进行处理 (判断map中是否有指定的键) 并且 只对新人用户做统计
        if (ObjectUtil.isEmpty(sessionPools.get(userId))) {
            // 在线人数 + 1
            addOnlineCount();
            sessionPools.put(userId, session);
        } else {
            // 无论是否多次连接都存入会话中
            String key = "repeat:" + IdUtil.fastSimpleUUID() + userId;
            sessionPools.put(key, session);
        }
        log.info("连接的用户ID为：" + onlineNum + "当前人数为：" + onlineNum);
    }

    private String getUserId(WebSocketSession session) {
        return session.getAttributes().get("userId").toString();
    }


    /**
     * 群发消息
     *
     * @param messageBody 需要发送的消息(SysNoticeDTO的JsonString格式)
     */
    public void sendToAllClient(String messageBody) {
        // 获取到所有会话
        Collection<WebSocketSession> sessions = sessionPools.values();
        for (WebSocketSession session : sessions) {
            try {
                // 服务器向客户端发送消息
                session.sendMessage(new TextMessage(messageBody));
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * 接收消息
     *
     * @param session WebSocketSession
     * @param message TextMessage
     */
    @Override
    @Deprecated // 用不上
    public void handleTextMessage(WebSocketSession session, TextMessage message) throws IOException {
        System.out.println("获取到消息 >> " + message.getPayload());
        session.sendMessage(new TextMessage(String.format("收到用户：【%s】发来的【%s】", getUserId(session), message.getPayload())));
    }


    /**
     * 连接关闭后对应的操作
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
        String userId = getUserId(session);
        sessionPools.remove(userId);
        // 在线人数 - 1
        subOnlineCount();
    }

    /**
     * 添加链接人数
     */
    public static void addOnlineCount() {
        onlineNum.incrementAndGet();
    }

    /**
     * 移除链接人数
     */
    public static void subOnlineCount() {
        onlineNum.decrementAndGet();
    }

}
